package com.xmzs.web.controller;


import com.xmzs.common.chat.config.LocalCache;
import com.xmzs.common.chat.domain.request.SaveMsgRequest;
import com.xmzs.common.chat.entity.chat.Message;
import com.xmzs.common.chat.service.SseService;
import com.xmzs.common.chat.domain.request.ChatRequest;
import com.xmzs.common.chat.utils.TikTokensUtil;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.ServiceException;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.system.domain.vo.SysUserVo;
import com.xmzs.system.service.ISysUserService;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * Web controller exposing the chat endpoints.
 *
 * <p>{@code POST /chat} checks the caller's token balance, charges the prompt
 * (plus optional conversation context) and delegates to {@link SseService} to
 * produce the response. {@code POST /saveMsg} charges the tokens of a message
 * supplied by the client.
 *
 * @author https://www.unfbx.com
 * @date 2023-03-01
 */
@Controller
@Slf4j
@RequiredArgsConstructor
public class ChatController {

    private final SseService sseService;

    private final ISysUserService sysUserService;

    /**
     * Chat endpoint.
     *
     * <p>Verifies that a user is logged in and has at least one token left, counts
     * the tokens of the prompt (together with the cached conversation history when
     * the request asks for context), deducts them from the user's balance and then
     * hands the request over to {@link SseService#sseChat}.
     *
     * <p>Note that tokens are deducted before the chat call is made; a failure
     * inside {@code sseService.sseChat} does not refund them here.
     *
     * @param chatRequest validated chat request (prompt, model, context settings)
     * @param response    servlet response; its content type is set to
     *                    {@code application/octet-stream}
     * @return the emitter produced by {@link SseService#sseChat}
     * @throws ServiceException if no user is logged in, the user record cannot be
     *                          found, or the token balance is below 1
     */
    @CrossOrigin
    @PostMapping("/chat")
    @ResponseBody
    public ResponseBodyEmitter sseChat(@RequestBody @Valid ChatRequest chatRequest, HttpServletResponse response) {
        response.setContentType(MediaType.APPLICATION_OCTET_STREAM_VALUE);
        LoginUser loginUser = LoginHelper.getLoginUser();
        if (loginUser == null) {
            throw new ServiceException("用户未登录！");
        }
        SysUserVo sysUserVo = sysUserService.selectUserById(loginUser.getUserId());
        // Fail with a meaningful error instead of a NullPointerException when the
        // session refers to a user that can no longer be loaded.
        if (sysUserVo == null) {
            throw new ServiceException("用户不存在！");
        }
        if (sysUserVo.getTokens() < 1) {
            throw new ServiceException("余额不足,请联系管理员充值！");
        }
        // Conversation context: empty unless the request opts in.
        LinkedList<Message> messages = new LinkedList<>();
        // Boolean.TRUE.equals treats an absent (null) flag as "no context" rather
        // than failing on auto-unboxing.
        if (Boolean.TRUE.equals(chatRequest.getUsingContext())) {
            // Load the stored conversation history.
            LinkedList<Message> history =
                    LocalCache.getUserChatMessages(chatRequest.getConversationId(), chatRequest.getContentNumber());
            // NOTE(review): assumes the cache may return null when there is no
            // history for this conversation; the guard is harmless otherwise.
            if (history != null) {
                messages = history;
            }
        }
        // Append the current prompt as a user message.
        Message message = Message.builder().content(chatRequest.getPrompt()).role(Message.Role.USER).build();
        messages.add(message);
        int tokens = TikTokensUtil.tokens(chatRequest.getModel(), new ArrayList<>(messages));
        sysUserService.deductToken(loginUser.getUserId(), tokens);
        return sseService.sseChat(chatRequest);
    }

    /**
     * Charges the logged-in user for a message supplied by the client.
     *
     * <p>Counts the tokens of {@code saveMsgRequest.getMsg()} for the given model,
     * deducts them from the user's balance and logs the result. Despite its name,
     * this method does not itself store the message anywhere.
     *
     * @param saveMsgRequest validated request carrying the model name and message text
     * @throws BaseException if no user is logged in
     */
    @CrossOrigin
    @PostMapping("/saveMsg")
    @ResponseBody
    public void saveMsg(@RequestBody @Valid SaveMsgRequest saveMsgRequest) {
        LoginUser loginUser = LoginHelper.getLoginUser();
        if (loginUser == null) {
            throw new BaseException("用户未登录！");
        }
        int tokens = TikTokensUtil.tokens(saveMsgRequest.getModel(), saveMsgRequest.getMsg());
        // Deduct the consumed tokens from the user's balance.
        sysUserService.deductToken(loginUser.getUserId(), tokens);
        log.info("保存对话记录id:{},消息:{},消耗tokens:{}", loginUser.getUserId(), saveMsgRequest.getMsg(), tokens);
    }
}
